import torch
import math
import random
import numpy as np
from Visualization import Visualization, Compute, funcplus, funcaverage

# NumFunc = 30

# Evaluation grid: both axes run from -1 to 10.0 in steps of 0.1 (the 10.01 end
# makes 10.0 inclusive). requires_grad=True lets autograd track the grid points.
x_0_s = torch.arange(-1, 10.01, 0.1, requires_grad=True)
x_1_s = torch.arange(-1, 10.01, 0.1, requires_grad=True)
# indexing="ij" is torch.meshgrid's existing default; passing it explicitly
# silences the deprecation warning raised when the argument is omitted.
x_s = torch.meshgrid(x_0_s, x_1_s, indexing="ij")

# Overall multiplier applied to every generated loss function.
scale = 17.5 * 5
# Coefficient of the quadratic bowl term (centred on the true critical point
# in ConstructLossFunctions).
intensity = 0.001

# 3x3 lattice of candidate critical points: every pair drawn from {1, 4, 7}.
CritPoints_x = torch.arange(1.0, 7.01, 3, requires_grad=False)
CritPoints_y = torch.arange(1.0, 7.01, 3, requires_grad=False)
CritPoints = torch.cartesian_prod(CritPoints_x, CritPoints_y)

# Probability that a sampled spurious bump gets a positive sign
# (ConstructLossFunctions uses sign = 2 * Bernoulli(PositiveRate) - 1).
PositiveRate = torch.tensor(0.35, requires_grad=False)
# Mean / standard deviation of the normal draw for spurious-bump magnitudes.
HeightMean = 1.0
HeightSTD = 1.0

# Per-axis widths of the spurious bump (sigma_1); its height is sampled per
# function in ConstructLossFunctions rather than fixed here.
sigma_1 = [1.3, 1.3]
# Per-axis widths and fixed height of the bump placed at the true critical point.
sigma_2 = [0.9, 0.9]
height_2 = -0.3

def NewLossFunction(sigma_1, mu_1, height_1, sigma_2, mu_2, height_2, intensity, center, scale):
    """Build a 2-D synthetic loss made of two Gaussian bumps plus a quadratic bowl.

    The returned callable evaluates

        scale * (g_1(x) + g_2(x) + intensity * ((x[0]-center[0])**2 + (x[1]-center[1])**2) - baseline)

    with ``g_k(x) = height_k * exp(-((x[0]-mu_k[0])**2 / (2*sigma_k[0]**2)
    + (x[1]-mu_k[1])**2 / (2*sigma_k[1]**2)))``. ``baseline`` is the same
    un-shifted, un-scaled expression evaluated at ``mu_2``, so the returned
    function is zero (up to rounding) at ``mu_2``.

    Args:
        sigma_1, sigma_2: per-axis widths of the two bumps (indexed [0], [1]).
        mu_1, mu_2: centres of the two bumps; tensors or plain number sequences.
        height_1, height_2: bump amplitudes (numbers or broadcastable tensors).
        intensity: coefficient of the quadratic term.
        center: centre of the quadratic term (indexed [0], [1]).
        scale: overall multiplicative factor.

    Returns:
        ``newfunc(x)``: ``x[0]`` and ``x[1]`` are the two coordinates (e.g. the
        pair from ``torch.meshgrid`` or a length-2 tensor). The result
        broadcasts elementwise over them.
    """
    def _landscape(p):
        # Un-shifted, un-scaled landscape at the point (p[0], p[1]).
        # torch.as_tensor lets the exponent be a plain float (all inputs plain
        # numbers, e.g. list centres at the baseline point); for tensors it
        # returns the same object, so autograd is untouched.
        bump_1 = height_1 * torch.exp(torch.as_tensor(
            -((p[0] - mu_1[0]) ** 2 / (2 * sigma_1[0] ** 2) + (p[1] - mu_1[1]) ** 2 / (2 * sigma_1[1] ** 2))))
        bump_2 = height_2 * torch.exp(torch.as_tensor(
            -((p[0] - mu_2[0]) ** 2 / (2 * sigma_2[0] ** 2) + (p[1] - mu_2[1]) ** 2 / (2 * sigma_2[1] ** 2))))
        bowl = intensity * ((p[0] - center[0]) ** 2 + (p[1] - center[1]) ** 2)
        return bump_1 + bump_2 + bowl

    def newfunc(x):
        # Baseline is recomputed per call (not hoisted) so each call builds its
        # own autograd graph when any parameter requires grad.
        baseline = _landscape(mu_2)
        return scale * (_landscape(x) - baseline)

    return newfunc

#newfunc1 = NewLossFunction(sigma_1=[1.0, 5.0], mu_1=torch.tensor([2.1, 1.2]), height_1=5.5, sigma_2=[1.0, 1.0],
#                           mu_2=torch.tensor([0.0, 0.0]), height_2=0.0, intensity=intensity, center=[0, 0], scale=scale)

def ConstructLossFunctions(NumFunc, SpuriousCritPoints, TrueCritPoint):
    """Sample ``NumFunc`` random loss functions that share one true critical point.

    Each function gets a spurious bump centred on a point drawn uniformly (with
    Python's ``random``) from ``SpuriousCritPoints``. The bump's height is a
    random sign (positive with probability ``PositiveRate``) times a
    ``Normal(HeightMean, HeightSTD)`` draw of shape [1]. The bump at
    ``TrueCritPoint`` and the quadratic bowl centred there use the
    module-level ``sigma_2``, ``height_2``, ``intensity`` and ``scale``.

    Returns:
        A list of ``NumFunc`` callables produced by ``NewLossFunction``.
    """
    def _sample_one():
        # Draw order (python RNG, then torch Bernoulli, then torch normal)
        # matters for reproducibility under fixed seeds.
        spurious_center = random.choices(SpuriousCritPoints, k=1)[0]
        sign = 2 * torch.bernoulli(PositiveRate) - 1
        magnitude = torch.normal(mean=HeightMean, std=HeightSTD, size=[1])
        return NewLossFunction(
            sigma_1=sigma_1,
            mu_1=spurious_center,
            height_1=sign * magnitude,
            sigma_2=sigma_2,
            mu_2=TrueCritPoint,
            height_2=height_2,
            intensity=intensity,
            center=TrueCritPoint,
            scale=scale,
        )

    return [_sample_one() for _ in range(NumFunc)]

#LossFunctions = [func1, func2, func3, func4, func5, func6, func7, func8]
#LossFunctions = [newfunc1]
#random.seed(RandomSeed)
#torch.manual_seed(RandomSeed)
#LossFunctions = ConstructLossFunctions(NumFunc=NumFunc, SpuriousCritPoints=CritPoints[:-1], TrueCritPoint=CritPoints[-1])
#Visualization(x_s, [funcaverage([LossFunctions[0], LossFunctions[0], LossFunctions[3], LossFunctions[3], LossFunctions[4], LossFunctions[5]])], 'Average')
#Visualization(x_s, LossFunctions, 'LossFunctions')
#Visualization(x_s, [funcaverage(LossFunctions)], 'LossFunctionsAverage')